package com.diven.spark.ml.learn.feature

import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.DataTypes
import com.diven.spark.ml.learn.core.BaseTest
import com.diven.spark.ml.learn.core.BaseSpark
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.feature.DCT
import org.apache.spark.ml.feature.StringIndexer
import org.apache.spark.ml.feature.StringIndexerModel
import org.apache.spark.ml.feature.IndexToString
import org.apache.spark.ml.attribute.NominalAttribute
import org.apache.spark.ml.attribute.Attribute

/**
 * Companion object of [[IndexToStringTest]]; provides a factory for the example.
 */
object IndexToStringTest extends BaseTest {

    /**
     * Builds a new [[IndexToStringTest]].
     *
     * @return a freshly constructed example instance, typed as [[BaseSpark]]
     */
    def apply(): BaseSpark = {
        new IndexToStringTest()
    }

}

/**
 * Example that indexes a string column with `StringIndexer` and then maps the
 * resulting indices back to strings with `IndexToString`.
 */
class IndexToStringTest extends BaseSpark {

    /**
     * Builds the sample input: six `(id, category)` rows.
     *
     * @param sparkSession session used to create the DataFrame
     * @return a DataFrame with the columns `id` and `category`
     */
    override def getDataFrame(sparkSession: SparkSession = this.getSparkSession()): DataFrame = {
        val sampleRows = Seq(
            (0, "a"), (1, "b"),
            (2, "c"), (3, "a"),
            (4, "a"), (5, "c")
        )
        val unnamed = sparkSession.createDataFrame(sampleRows)
        unnamed.toDF("id", "category")
    }

    /**
     * Indexes the `category` column, converts the indices back to strings and
     * prints the resulting DataFrame.
     *
     * @param dataFrame input data; must contain a `category` column
     */
    override def execute(dataFrame: DataFrame) = {
        // Configure the indexer and fit it on the input data.
        val indexer = new StringIndexer()
        indexer.setInputCol("category")
        indexer.setOutputCol("categoryIndex")
        val indexerModel = indexer.fit(dataFrame)

        // Add the "categoryIndex" column to the data.
        val indexed = indexerModel.transform(dataFrame)

        // Configure the reverse mapping from index to string.
        // If a label table is set explicitly it is the one used; if none is set,
        // the labels are taken from the metadata of the "categoryIndex" column.
        // NOTE(review): this label order is presumably meant to match the indices
        // assigned by the indexer above - confirm.
        val labels = Array("a", "c", "b")
        val converter = new IndexToString()
        converter.setInputCol("categoryIndex")
        converter.setOutputCol("originalCategory")
        converter.setLabels(labels)

        // Convert the indices back to strings and print the result.
        val converted = converter.transform(indexed)
        converted.show()
    }

}